# Source code for hysop.operator.base.spectral_operator

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
import math
import os
import sympy as sm
import numpy as np

from hysop.constants import (
    BoundaryCondition,
    BoundaryExtension,
    TransformType,
    MemoryOrdering,
    TranspositionState,
    Backend,
    SpectralTransformAction,
    Implementation,
)
from hysop.tools.misc import compute_nbytes
from hysop.tools.htypes import check_instance, to_tuple, first_not_None, to_set
from hysop.tools.decorators import debug
from hysop.tools.units import bytes2str
from hysop.tools.numerics import (
    is_fp,
    is_complex,
    complex_to_float_dtype,
    float_to_complex_dtype,
    determine_fp_types,
)
from hysop.tools.io_utils import IOParams
from hysop.tools.spectral_utils import (
    SpectralTransformUtils as STU,
    EnergyPlotter,
    EnergyDumper,
)
from hysop.core.arrays.array_backend import ArrayBackend
from hysop.core.arrays.array import Array
from hysop.core.memory.memory_request import MemoryRequest, OperatorMemoryRequests
from hysop.core.graph.graph import (
    not_initialized as _not_initialized,
    initialized as _initialized,
    discretized as _discretized,
    ready as _ready,
)
from hysop.core.graph.computational_node_frontend import ComputationalGraphNodeFrontend
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.parameters.buffer_parameter import BufferParameter
from hysop.fields.continuous_field import Field, ScalarField, TensorField
from hysop.symbolic.array import SymbolicArray
from hysop.symbolic.spectral import (
    WaveNumber,
    SpectralTransform,
    AppliedSpectralTransform,
)
from hysop.numerics.fft.fft import (
    FFTI,
    simd_alignment,
    is_byte_aligned,
    HysopFFTWarning,
)


class SpectralComputationalGraphNodeFrontend(ComputationalGraphNodeFrontend):
    """
    Graph node frontend for spectral operators.

    On top of the regular frontend behaviour, this class adds an
    implementation-selection step: when ``enforce_implementation`` is False,
    the OPENCL implementation may be swapped for the PYTHON one on CPU OpenCL
    devices (see :meth:`get_actual_implementation`).
    """

    def __new__(cls, implementation, enforce_implementation=True, **kwds):
        # enforce_implementation is only consumed by __init__.
        return super().__new__(cls, implementation=implementation, **kwds)

    def __init__(self, implementation, enforce_implementation=True, **kwds):
        (actual_impl, extra_kwds) = self.get_actual_implementation(
            implementation=implementation,
            enforce_implementation=enforce_implementation,
            **kwds,
        )
        # Implementation selection may inject extra keyword arguments
        # (e.g. buffer mapping); they must not collide with user ones.
        assert all(key not in kwds for key in extra_kwds)
        kwds.update(extra_kwds)
        super().__init__(implementation=actual_impl, **kwds)

    @classmethod
    def get_actual_implementation(
        cls, implementation, enforce_implementation=True, cl_env=None, **kwds
    ):
        """
        Return the implementation that will actually be used, together with
        the extra keyword arguments it requires.

        Parameters
        ----------
        implementation: Implementation, optional, defaults to None
            User desired target implementation.
        enforce_implementation: bool, optional, defaults to True
            If this is set to True, input implementation is enforced.
            Else, this function may select another implementation
            when some conditions are met:

            Case 1: Host FFT by mapping CPU OpenCL buffers
                Conditions:
                    a/ input implementation is set to OPENCL
                    b/ cl_env.device is of type CPU
                    c/ Implementation.PYTHON is a valid operator implementation
                    d/ Target python operator supports OPENCL as backend
                    e/ OpenCL platform has zero copy capabilities (cannot be checked)
                    => If cl_env is not given, this will yield a RuntimeError
                => In this case PYTHON implementation is chosen instead.
                   Buffers are mapped to host. By default this should give
                   multithreaded FFTW + multithreaded numba.

            For all other cases, this parameter is ignored.
        cl_env: OpenCL environment, optional
            Required to inspect the device type when enforce_implementation
            is False and implementation is OPENCL.

        Notes
        -----
        clFFT (gpyFFT) support for OpenCL CPU devices is a bit neglected.
        This function allows to override the implementation target from
        OPENCL to PYTHON when a CPU OpenCL environment is given as input.

        By default, the CPU FFT target is FFTW (pyFFTW) which has much better
        support (multithreaded fftw + multithreaded numba).

        OpenCL buffers are mapped to host memory with enqueue_map_buffer
        (this makes the assumption that all OpenCL buffers have been allocated
        with zero-copy capability in the target OpenCL platform).
        """
        implementation = first_not_None(implementation, cls.default_implementation())
        assert implementation in cls.implementations()
        extra_kwds = {}

        # Guard clauses: only the (non-enforced, OPENCL) case may be overridden.
        if enforce_implementation:
            return (implementation, extra_kwds)
        if implementation != Implementation.OPENCL:
            return (implementation, extra_kwds)
        if cl_env is None:
            msg = (
                "enforce_implementation was set to False, "
                "implementation is OPENCL, but no cl_env was passed "
                "to check if the device is of type CPU."
            )
            raise RuntimeError(msg)

        from hysop.backend.device.opencl import cl

        if cl_env.device.type != cl.device_type.CPU:
            return (implementation, extra_kwds)
        if Implementation.PYTHON not in cls.implementations():
            return (implementation, extra_kwds)

        from hysop.backend.host.host_operator import HostOperator, OpenClMappable

        op_cls = cls.implementations()[Implementation.PYTHON]
        if not issubclass(op_cls, HostOperator):
            raise TypeError("Operator {} is not a HostOperator.".format(op_cls))
        if not issubclass(op_cls, OpenClMappable):
            raise TypeError(
                "Operator {} does not support host to device opencl buffer mapping.".format(
                    op_cls
                )
            )
        assert Backend.HOST in op_cls.supported_backends()
        assert Backend.OPENCL in op_cls.supported_backends()
        extra_kwds["enable_opencl_host_buffer_mapping"] = True
        return (Implementation.PYTHON, extra_kwds)
class SpectralOperatorBase:
    """
    Common implementation interface for spectral based operators.

    A spectral operator owns one or more SpectralTransformGroup instances,
    each of which collects the forward/backward transforms and wavenumbers
    required by the operator. The actual FFT interface and array backend
    are chosen per transform group at discretization time.
    """

    # Minimal alignment (in bytes) of FFT work buffers (FFTW SIMD requirement).
    min_fft_alignment = simd_alignment  # FFTW SIMD.

    @debug
    def __new__(cls, fft_interface=None, fft_interface_kwds=None, **kwds):
        return super().__new__(cls, **kwds)

    @debug
    def __init__(self, fft_interface=None, fft_interface_kwds=None, **kwds):
        """
        Initialize a spectral operator base.

        Parameters
        ----------
        fft_interface: FFTI, optional
            Custom FFT interface shared by all transform groups.
            Defaults to the backend default interface when None.
        fft_interface_kwds: dict, optional
            Keyword arguments used to build the default FFT interface
            (only meaningful when fft_interface is None).
        kwds: dict
            Base class keyword arguments.
        """
        super().__init__(**kwds)
        check_instance(fft_interface, FFTI, allow_none=True)
        check_instance(fft_interface_kwds, dict, allow_none=True)
        self.transform_groups = {}  # dict[tag] -> SpectralTransformGroup
        # those values will be deleted at discretization
        self._fft_interface = fft_interface
        self._fft_interface_kwds = fft_interface_kwds

    @property
    def backend(self):
        # There is no single backend: each transform group may have its own.
        msg = "FFT array backend depends on the transform group. Please use op.transform_group[key].backend instead."
        raise AttributeError(msg)

    @property
    def FFTI(self):
        # There is no single FFT interface: each transform group may have its own.
        msg = "FFT interface depends on the transform group. Please use op.transform_group[key].FFTI instead."
        raise AttributeError(msg)

    def new_transform_group(self, tag=None, mem_tag=None):
        """
        Register a new SpectralTransformGroup to this spectral operator.
        A SpectralTransformGroup is an object that collect forward and backward
        field transforms as well as symbolic expressions and wave_numbers symbols.
        """
        n = len(self.transform_groups)
        tag = first_not_None(tag, f"transform_group_{n}")
        msg = 'Tag "{}" has already been registered.'
        assert tag not in self.transform_groups, msg.format(tag)
        trg = SpectralTransformGroup(op=self, tag=tag, mem_tag=mem_tag)
        self.transform_groups[tag] = trg
        return trg

    def pre_initialize(self, **kwds):
        """Collect output parameters (energy buffers) from all transform groups."""
        # NOTE(review): no super().pre_initialize(**kwds) call here — presumably
        # intentional, verify against the base class contract.
        output_parameters = set()
        for tg in self.transform_groups.values():
            output_parameters.update(tg.output_parameters)
        # single bulk update instead of one update({p}) call per parameter
        self.output_params.update(output_parameters)

    def initialize(self, **kwds):
        """Initialize this operator and all of its transform groups."""
        super().initialize(**kwds)
        for tg in self.transform_groups.values():
            tg.initialize(**kwds)

    def get_field_requirements(self):
        """Force C-contiguous, default-transposed fields, splittable on all axes but the last."""
        requirements = super().get_field_requirements()
        for is_input, (field, td, req) in requirements.iter_requirements():
            req.memory_order = MemoryOrdering.C_CONTIGUOUS
            req.axes = (TranspositionState[field.dim].default_axes(),)
            # the contiguous (last) axis cannot be split across processes
            can_split = req.can_split
            can_split[-1] = False
            can_split[:-1] = True
            req.can_split = can_split
        return requirements

    @debug
    def get_node_requirements(self):
        """Require a unique topology shape for all fields of this node."""
        node_reqs = super().get_node_requirements()
        node_reqs.enforce_unique_topology_shape = True
        return node_reqs

    def discretize(self, **kwds):
        """Discretize all transform groups. Only single-process (no MPI) runs are supported."""
        super().discretize(**kwds)
        if self.mpi_params.size > 1:
            msg = "\n[FATAL ERROR] Spectral operators do not support the MPI interface yet."
            msg += "\nPlease use the Fortran FFTW interface if possible or "
            msg += "use another discretization method for operator {}.\n"
            msg = msg.format(self.node_tag)
            print(msg)
            raise NotImplementedError
        for tg in self.transform_groups.values():
            tg.discretize(
                fft_interface=self._fft_interface,
                fft_interface_kwds=self._fft_interface_kwds,
                enable_opencl_host_buffer_mapping=self.enable_opencl_host_buffer_mapping,
                **kwds,
            )
        # interface configuration has been consumed by the transform groups
        del self._fft_interface
        del self._fft_interface_kwds

    def get_mem_requests(self, **kwds):
        """Merge temporary buffer requests of all transform groups, keyed by (name, backend)."""
        memory_requests = {}
        for tg in self.transform_groups.values():
            for k, v in tg.get_mem_requests(**kwds).items():
                check_instance(k, str)  # temporary buffer name
                check_instance(v, (int, np.integer))  # nbytes
                K = (k, tg.backend)
                # the same buffer may be requested by several groups: keep the max size
                if K in memory_requests:
                    memory_requests[K] = max(memory_requests[K], v)
                else:
                    memory_requests[K] = v
        return memory_requests

    def get_work_properties(self, **kwds):
        """Convert merged memory requests into FFT-aligned MemoryRequest objects."""
        requests = super().get_work_properties(**kwds)
        for (k, backend), v in self.get_mem_requests(**kwds).items():
            check_instance(k, str)
            check_instance(v, (int, np.integer))
            if v > 0:
                mrequest = MemoryRequest(
                    backend=backend, size=v, alignment=self.min_fft_alignment
                )
                requests.push_mem_request(request_identifier=k, mem_request=mrequest)
        return requests

    def setup(self, work):
        """Bind allocated work buffers to temporary fields and all transform groups."""
        self.allocate_tmp_fields(work)
        for tg in self.transform_groups.values():
            tg.setup(work=work)
        super().setup(work=work)
class SpectralTransformGroup:
    """
    Build and check a FFT transform group.

    This object tells the planner to build a full forward transform for all
    given forward_fields. The planner will also build backward transforms
    for all specified backward_fields.

    The object will also automatically build per-axis wavenumbers up to
    certain powers, extracted from user provided sympy expressions.

    Finally boundary condition (ie. transform type) compatibility will be
    checked by using user provided sympy expressions.

    Calling a forward transform ensures that the forward source field is
    read-only and not destroyed.
    """

    # When True, dump verbose wavenumber/expression information to stdout.
    DEBUG = False

    def __new__(cls, op, tag, mem_tag, **kwds):
        return super().__new__(cls, **kwds)

    def __init__(self, op, tag, mem_tag, **kwds):
        """
        Parameters
        ----------
        op : SpectralOperatorBase
            Operator that creates this SpectralTransformGroup.
        tag: str
            A tag to identify this transform group.
            Each tag can only be registered once in a SpectralOperatorBase instance.
        mem_tag: str, optional
            Memory pool identifier, defaults to 'fft_pool'.

        Attributes
        ----------
        tag: str
            SpectralTransformGroup identifier.
        mem_tag: str
            SpectralTransformGroup memory pool identifier.
        forward_transforms: dict of forward SpectralTransform
            Forward fields to be planned for transform, according to Field
            boundary conditions.
        backward_transforms: dict of backward SpectralTransform
            Backward fields to be planned for transform, according to Field
            boundary conditions.

        Notes
        -----
        All forward_fields and backward_fields have to live on the same domain
        and their boundary conditions should comply with given expressions.
        """
        super().__init__(**kwds)
        mem_tag = first_not_None(mem_tag, "fft_pool")
        check_instance(op, SpectralOperatorBase)
        check_instance(tag, str)
        check_instance(mem_tag, str)
        self._op = op
        self._tag = tag
        self._mem_tag = mem_tag
        # keyed by (field, axes, transform_tag)
        self._forward_transforms = {}
        self._backward_transforms = {}
        self._wave_numbers = set()
        self._indexed_wave_numbers = {}
        self._expressions = ()
        self._discrete_wave_numbers = None

    def indexed_wavenumbers(self, *wave_numbers):
        """Return the indexed (symbolic) counterparts of given wavenumber symbols."""
        return tuple(self._indexed_wave_numbers[Wi] for Wi in wave_numbers)

    @property
    def op(self):
        return self._op

    @property
    def tag(self):
        return self._tag

    @property
    def mem_tag(self):
        return self._mem_tag

    @property
    def name(self):
        return self._tag

    @property
    def initialized(self):
        # state is delegated to the owning operator
        return self._op.initialized

    @property
    def discretized(self):
        return self._op.discretized

    @property
    def ready(self):
        return self._op.ready

    @property
    def forward_fields(self):
        # keys are (field, axes, transform_tag) tuples
        return tuple(map(lambda x: x[0], self._forward_transforms.keys()))

    @property
    def backward_fields(self):
        return tuple(map(lambda x: x[0], self._backward_transforms.keys()))

    @property
    def forward_transforms(self):
        return self._forward_transforms

    @property
    def backward_transforms(self):
        return self._backward_transforms

    @_not_initialized
    def initialize(
        self, fft_granularity=None, fft_concurrent_plans=1, fft_plan_workload=1, **kwds
    ):
        """
        Should be called after all require_forward_transform and
        require_backward_transform calls.

        Parameters
        ----------
        fft_granularity: int, optional
            Granularity of each directional fft plan.
            1:   iterate over 1d lines   (slices of dimension 1)
            2:   iterate over 2d planes  (slices of dimension 2)
            3:   iterate over 3d blocks  (slices of dimension 3)
            n-1: iterate over hyperplans (slices of dimension n-1)
            n :  no iteration, the plan will handle the whole domain.
                 Contiguous buffers with sufficient alignment are allocated.
            Default value is: 1 in 1D else n-1 (ie. hyperplans)
        fft_plan_workload: int, optional, defaults to 1
            The number of blocks of dimension fft_granularity that a single
            plan will handle at once. Default is one block.
        fft_concurrent_plans: int, optional, defaults to 1
            Number of concurrent plans.
            Should be 1 for HOST based FFT interfaces.
            Should be at least 3 for DEVICE based FFT interface if the device
            has two async copy engine (copy, transform, copy).
        """
        (domain, dim) = self.check_fields(self.forward_fields, self.backward_fields)
        fft_granularity = first_not_None(fft_granularity, max(1, dim - 1))
        check_instance(fft_granularity, int, minval=1, maxval=dim)
        check_instance(fft_concurrent_plans, int, minval=1)
        check_instance(fft_plan_workload, int, minval=1)
        self._fft_granularity = fft_granularity
        self._fft_concurrent_plans = fft_concurrent_plans
        self._fft_plan_workload = fft_plan_workload
        self._domain = domain
        self._dim = dim

    @_initialized
    def discretize(
        self,
        fft_interface,
        fft_interface_kwds,
        enable_opencl_host_buffer_mapping,
        **kwds,
    ):
        """
        Discretize all planned transforms, check their consistency, pick the
        FFT interface and build discrete wavenumber arrays.
        """
        # Collect per-transform discretization properties; all transforms of a
        # group must agree on backend, grid size, axes, shape and dtype.
        backends = set()
        grid_resolutions = set()
        compute_axes = set()
        compute_shapes = set()
        compute_dtypes = set()
        for fwd in self.forward_transforms.values():
            fwd.discretize()
            backends.add(fwd.backend)
            grid_resolutions.add(to_tuple(fwd.dfield.mesh.grid_resolution))
            compute_axes.add(fwd.output_axes)
            compute_shapes.add(fwd.output_shape)
            compute_dtypes.add(fwd.output_dtype)
        for bwd in self.backward_transforms.values():
            bwd.discretize()
            backends.add(bwd.backend)
            grid_resolutions.add(to_tuple(bwd.dfield.mesh.grid_resolution))
            compute_axes.add(bwd.input_axes)
            compute_shapes.add(bwd.input_shape)
            compute_dtypes.add(bwd.input_dtype)

        def format_error(data):
            return "\n *" + "\n *".join(str(x) for x in data)

        msg = "Fields do not live on the same backend:" + format_error(backends)
        assert len(backends) == 1, msg
        msg = "Fields grid size mismatch:" + format_error(grid_resolutions)
        assert len(grid_resolutions) == 1, msg
        assert len(compute_axes) == 1, "Fields axes mismatch:" + format_error(
            compute_axes
        )
        assert len(compute_shapes) == 1, "Fields shape mismatch:" + format_error(
            compute_shapes
        )
        assert len(compute_dtypes) == 1, "Fields data type mismatch." + format_error(
            compute_dtypes
        )
        backend = next(iter(backends))
        grid_resolution = next(iter(grid_resolutions))
        compute_axes = next(iter(compute_axes))
        compute_shape = next(iter(compute_shapes))
        compute_dtype = next(iter(compute_dtypes))

        if enable_opencl_host_buffer_mapping:
            msg = "Trying to enable opencl device to host buffer mapping on {} target."
            assert backend.kind is Backend.OPENCL, msg.format(backend.kind)

        # Build (or check) the FFT interface for the common backend.
        if fft_interface is None:
            fft_interface_kwds = first_not_None(fft_interface_kwds, {})
            fft_interface = FFTI.default_interface_from_backend(
                backend,
                enable_opencl_host_buffer_mapping=enable_opencl_host_buffer_mapping,
                **fft_interface_kwds,
            )
        else:
            assert not fft_interface_kwds, "FFT interface has already been built."
        check_instance(fft_interface, FFTI)
        fft_interface.check_backend(
            backend, enable_opencl_host_buffer_mapping=enable_opencl_host_buffer_mapping
        )

        buffer_backend = backend
        host_backend = backend.host_array_backend
        # the FFT interface may impose its own compute backend (e.g. host mapping)
        backend = fft_interface.backend

        # Materialize every symbolic wavenumber as a discrete frequency array.
        discrete_wave_numbers = {}
        for wn in self._wave_numbers:
            (idx, freqs, nd_freqs) = self.build_wave_number(
                self._domain,
                grid_resolution,
                backend,
                wn,
                compute_dtype,
                compute_axes,
                compute_shape,
            )
            self._indexed_wave_numbers[wn].indexed_object.to_backend(
                backend.kind
            ).bind_memory_object(freqs)
            self._indexed_wave_numbers[wn].index.bind_axes(compute_axes)
            discrete_wave_numbers[wn] = (idx, freqs, nd_freqs)
        self._discrete_wave_numbers = discrete_wave_numbers

        self.buffer_backend = buffer_backend
        self.host_backend = host_backend
        self.backend = backend
        self.FFTI = fft_interface
        self.grid_resolution = grid_resolution
        self.compute_axes = compute_axes
        self.compute_shape = compute_shape
        self.compute_dtype = compute_dtype

    @classmethod
    def build_wave_number(
        cls,
        domain,
        grid_resolution,
        backend,
        wave_number,
        compute_dtype,
        compute_axes,
        compute_resolution,
    ):
        """
        Build the discrete frequency array for one symbolic wavenumber.

        Returns (idx, freqs, nd_freqs) where idx is the compute axis index,
        freqs the 1d frequency array (raised to the wavenumber exponent) on
        the given backend, and nd_freqs a broadcastable nd view of it.
        """
        dim = domain.dim
        length = domain.length
        ftype, ctype = determine_fp_types(compute_dtype)

        axis = wave_number.axis
        transform = wave_number.transform
        exponent = wave_number.exponent

        # fail fast, before any frequency computation (was checked after)
        assert exponent != 0, "exponent cannot be zero."
        assert exponent > 0, "negative powers not implemented yet."

        idx = compute_axes.index(axis)
        L = domain.length[axis]
        N = grid_resolution[axis]

        freqs = STU.compute_wave_numbers(transform=transform, N=N, L=L, ftype=ftype)
        freqs = freqs**exponent
        if STU.is_R2R(transform):
            # sign convention for real-to-real transforms depends on parity
            sign_offset = STU.is_cosine(transform)
            freqs *= (-1) ** ((exponent + sign_offset) // 2)

        # even powers of purely imaginary wavenumbers are real
        if is_complex(freqs.dtype) and (exponent % 2 == 0):
            assert freqs.imag.sum() == 0
            freqs = freqs.real.copy()

        # move frequencies to the target compute backend
        backend_freqs = backend.empty_like(freqs)
        backend_freqs[...] = freqs
        freqs = backend_freqs

        # broadcastable nd view: size 1 on every axis but idx
        nd_shape = [
            1,
        ] * dim
        nd_shape[idx] = freqs.size
        nd_shape = tuple(nd_shape)
        nd_freqs = freqs.reshape(nd_shape)

        if cls.DEBUG:
            print()
            print("BUILD WAVENUMBER")
            print(f"backend: {backend.kind}")
            print(f"grid_shape: {grid_resolution}")
            print(f"length: {length}")
            print("-----")
            print(f"ftype: {ftype}")
            print(f"ctype: {ctype}")
            print(f"compute shape: {compute_resolution}")
            print(f"compute axes: {compute_axes}")
            print("-----")
            print("wave_number:")
            print(f" *symbolic: {wave_number}")
            print(f" *axis: {axis}")
            print(f" *transform: {transform}")
            print(f" *exponent: {exponent}")
            print("----")
            print(f"L: {L}")
            print(f"N: {N}")
            print(f"freqs: {freqs}")
            print(f"nd_freqs: {nd_freqs}")
            print("----")
        return (idx, freqs, nd_freqs)

    @_discretized
    def get_mem_requests(self, **kwds):
        """Merge temporary buffer requests (name -> nbytes) of all transforms, keeping max sizes."""
        memory_requests = {}
        for fwd in self.forward_transforms.values():
            mem_requests = fwd.get_mem_requests(**kwds)
            check_instance(mem_requests, dict, keys=str, values=(int, np.integer))
            for k, v in mem_requests.items():
                if k in memory_requests:
                    memory_requests[k] = max(memory_requests[k], v)
                else:
                    memory_requests[k] = v
        for bwd in self.backward_transforms.values():
            mem_requests = bwd.get_mem_requests(**kwds)
            check_instance(mem_requests, dict, keys=str, values=(int, np.integer))
            for k, v in mem_requests.items():
                if k in memory_requests:
                    memory_requests[k] = max(memory_requests[k], v)
                else:
                    memory_requests[k] = v
        return memory_requests

    @_discretized
    def setup(self, work):
        """Setup all forward then backward transforms with allocated work buffers."""
        for fwd in self.forward_transforms.values():
            fwd.setup(work=work)
        for bwd in self.backward_transforms.values():
            bwd.setup(work=work)

    @_not_initialized
    def require_forward_transform(
        self,
        field,
        axes=None,
        transform_tag=None,
        custom_output_buffer=None,
        action=None,
        dump_energy=None,
        plot_energy=None,
        **kwds,
    ):
        """
        Tells this SpectralTransformGroup to build a forward SpectralTransform
        on given field. Only specified axes are transformed.

        Boundary condition to FFT extension mapping:
            Periodic:              Periodic extension
            Homogeneous Dirichlet: Odd extension
            Homogeneous Neumann:   Even extension

        This leads to 5 possible transforms for each axis (periodic-periodic,
        even-even, odd-odd, even-odd, odd-even).

        Forward transforms used for each axis per extension pair:
            *Periodic-Periodic   (PER-PER):   DFT (C2C, R2C for the first periodic axis)
            *Dirichlet-Dirichlet (ODD-ODD):   DST-I
            *Dirichlet-Neumann   (ODD-EVEN):  DST-III
            *Neumann-Dirichlet   (EVEN-ODD):  DCT-III
            *Neumann-Neumann     (EVEN-EVEN): DCT-I

        This method will return the SpectralTransform object associated to field.

        Parameters
        ----------
        field: ScalarField
            The source field to be transformed.
        axes: array-like of integers
            The axes to be transformed.
        transform_tag: str
            Extra tag to register the forward transform (a single scalar field
            can be transformed multiple times). Default tag is 'default'.
        custom_output_buffer: None or str, optional
            Force this transform to output in one of the two common transform
            group buffers. Default None value will force the user allocate an
            output buffer. Specifying 'B0' or 'B1' will tell the planner to
            output the transform in one of the two transform group buffers
            (that are used during all forward and backward transforms of the
            same transform group). This feature allows FFT operators to save
            one buffer for the last forward transform. Specifying 'auto' will
            tell the planner to choose either 'B0' or 'B1'.
        action: SpectralTransformAction, optional
            Defaults to SpectralTransformAction.OVERWRITE which will overwrite
            the compute slices of the output buffer.
            SpectralTransformAction.ACCUMULATE will sum the current content of
            the buffer with the result of the forward transform.
        dump_energy: IOParams, optional, defaults to None
            Compute the energy for each wavenumber at given frequency after
            each transform. If None is passed, no files are generated (default
            behaviour).
        plot_energy: IOParams, optional, defaults to None
            Plot field energy after each call to the forward transform to a
            custom file. If None is passed, no plots are generated (default
            behaviour).
        compute_energy_frequencies: array like of integers, optional, defaults to None
            Extra frequencies where to compute energy.

        Notes
        -----
        IOParams filename is formatted before being used:
            {fname} is replaced with discrete field name
            {ite}   is replaced with simulation iteration id for plotting and
                    '' for file dumping.

        dump_energy  plot_energy   result
        None         None          nothing
        iop0         0             energy is computed and dumped every iop0.frequency iterations
        0            iop1          energy is computed and dumped every iop1.frequency iterations
        iop0         iop1          energy is computed every iop0.frequency and iop1.frequency iterations
                                   dumped every iop0.frequency plotted every iop1.frequency

        About frequency:
            if (frequency<0)  no dump
            if (frequency==0) dump at time of interests and last iteration
            if (frequency>0)  dump at time of interests, last iteration and every freq iterations
        """
        transform_tag = first_not_None(transform_tag, "default")
        action = first_not_None(action, SpectralTransformAction.OVERWRITE)
        check_instance(field, Field)
        check_instance(transform_tag, str)
        check_instance(action, SpectralTransformAction)
        # was: a discarded SpectralTransform(..., forward=False) was built here
        transforms = SpectralTransform(field=field, axes=axes, forward=True)
        msg = 'Field {} with axes {} and transform_tag "{}" has already been registered for forward transform.'
        if field.is_tensor:
            planned_transforms = field.new_empty_array()
            for idx, f in field.nd_iter():
                assert (
                    f,
                    axes,
                    transform_tag,
                ) not in self._forward_transforms, msg.format(
                    f.name, axes, transform_tag
                )
                assert f in self._op.input_fields
                assert f is transforms[idx].field
                assert transforms[idx].is_forward
                planned_transforms[idx] = PlannedSpectralTransform(
                    transform_group=self,
                    tag=self.tag + "_" + transform_tag + "_" + f.name,
                    symbolic_transform=transforms[idx],
                    custom_output_buffer=custom_output_buffer,
                    action=action,
                    dump_energy=dump_energy,
                    plot_energy=plot_energy,
                    **kwds,
                )
                self._forward_transforms[(f, axes, transform_tag)] = planned_transforms[
                    idx
                ]
        else:
            assert (
                field,
                axes,
                transform_tag,
            ) not in self._forward_transforms, msg.format(
                field.name, axes, transform_tag
            )
            assert field in self._op.input_fields
            assert field is transforms.field
            assert transforms.is_forward
            planned_transforms = PlannedSpectralTransform(
                transform_group=self,
                tag=self.tag + "_" + transform_tag + "_" + field.name,
                symbolic_transform=transforms,
                custom_output_buffer=custom_output_buffer,
                action=action,
                dump_energy=dump_energy,
                plot_energy=plot_energy,
                **kwds,
            )
            self._forward_transforms[(field, axes, transform_tag)] = planned_transforms
        return planned_transforms

    @_not_initialized
    def require_backward_transform(
        self,
        field,
        axes=None,
        transform_tag=None,
        custom_input_buffer=None,
        matching_forward_transform=None,
        action=None,
        dump_energy=None,
        plot_energy=None,
        **kwds,
    ):
        """
        Same as require_forward_transform but for backward transforms.

        This corresponds to the following backward transform mappings:
        if order[axis] is 0:
            *no transform -> no transform
        else, if order[axis] is even:
            *C2C     -> C2C
            *R2C     -> C2R
            *DCT-I   -> DCT-I
            *DCT-III -> DCT-II
            *DST-I   -> DST-I
            *DST-III -> DST-II
        else: (if order[axis] is odd)
            *C2C     -> C2C
            *R2C     -> C2R
            *DCT-I   -> DST-I
            *DCT-III -> DST-II
            *DST-I   -> DCT-I
            *DST-III -> DCT-II

        For backward transforms, boundary compatibility for output_fields is
        thus the following:
        if axis is even:
            Boundary should be exactly the same on the axis.
        else, if axis is odd, boundary conditions change on this axis:
            *(Periodic-Periodic)   PER-PER   -> PER-PER   (Periodic-Periodic)
            *(Neumann-Neumann)     EVEN-EVEN -> ODD-ODD   (Dirichlet-Dirichlet)
            *(Neumann-Dirichlet)   EVEN-ODD  -> ODD-EVEN  (Dirichlet-Neumann)
            *(Dirichlet-Neumann)   ODD-EVEN  -> EVEN-ODD  (Neumann-Dirichlet)
            *(Dirichlet-Dirichlet) ODD-ODD   -> EVEN-EVEN (Neumann-Neumann)

        Order and boundary conditions are deduced from field.

        Parameters
        ----------
        field: ScalarField
            The target field where the result of the inverse transform will be
            stored.
        axes: array-like of integers
            The axes to be transformed.
        transform_tag: str
            Extra tag to register the backward transform (a single scalar
            field can be transformed multiple times). Default tag is 'default'.
        custom_input_buffer: None or str, optional
            Force this transform to take as input one of the two common
            transform group buffers. Default None value will force the user to
            supply an input buffer. Specifying 'B0' or 'B1' will tell the
            planner to take as transform input one of the two transform group
            buffers (that are used during all forward and backward transforms
            of the same transform group). This feature allows FFT operators to
            save one buffer for the first backward transform. Specifying
            'auto' will tell the planner to use the matching transform output
            buffer.
        action: SpectralTransformAction, optional
            Defaults to SpectralTransformAction.OVERWRITE which will overwrite
            the compute slices of the given output field.
            SpectralTransformAction.ACCUMULATE will sum the current content of
            the field with the result of the backward transform.
        dump_energy: IOParams, optional, defaults to None
            Compute the energy for each wavenumber at given frequency before
            each transform. If None is passed, no files are generated (default
            behaviour).
        plot_energy: IOParams, optional, defaults to None
            Plot field energy before each call to the backward transform to a
            custom file. If None is passed, no plots are generated (default
            behaviour).
        compute_energy_frequencies: array like of integers, optional, defaults to None
            Extra frequencies where to compute energy.

        Notes
        -----
        IOParams filename is formatted before being used:
            {fname} is replaced with discrete field name
            {ite}   is replaced with simulation iteration id for plotting and
                    '' for file dumping.

        dump_energy  plot_energy   result
        None         None          nothing
        iop0         0             energy is computed and dumped every iop0.frequency iterations
        0            iop1          energy is computed and dumped every iop1.frequency iterations
        iop0         iop1          energy is computed every iop0.frequency and iop1.frequency iterations
                                   dumped every iop0.frequency plotted every iop1.frequency

        About frequency:
            if (frequency<0)  no dump
            if (frequency==0) dump at time of interests and last iteration
            if (frequency>0)  dump at time of interests, last iteration and every freq iterations
        """
        transform_tag = first_not_None(transform_tag, "default")
        action = first_not_None(action, SpectralTransformAction.OVERWRITE)
        check_instance(field, Field)
        check_instance(transform_tag, str)
        check_instance(action, SpectralTransformAction)
        transforms = SpectralTransform(field=field, axes=axes, forward=False)
        msg = 'Field {} with axes {} and transform_tag "{}" has already been registered for backward transform.'
        if field.is_tensor:
            planned_transforms = field.new_empty_array()
            for idx, f in field.nd_iter():
                assert (
                    f,
                    axes,
                    transform_tag,
                ) not in self._backward_transforms, msg.format(
                    f.name, axes, transform_tag
                )
                assert f in self._op.output_fields
                assert not transforms[idx].is_forward
                planned_transforms[idx] = PlannedSpectralTransform(
                    transform_group=self,
                    tag=self.tag + "_" + transform_tag + "_" + f.name,
                    symbolic_transform=transforms[idx],
                    custom_input_buffer=custom_input_buffer,
                    matching_forward_transform=matching_forward_transform,
                    action=action,
                    dump_energy=dump_energy,
                    plot_energy=plot_energy,
                    **kwds,
                )
                self._backward_transforms[(f, axes, transform_tag)] = (
                    planned_transforms[idx]
                )
        else:
            assert (
                field,
                axes,
                transform_tag,
            ) not in self._backward_transforms, msg.format(
                field.name, axes, transform_tag
            )
            assert field in self._op.output_fields
            assert not transforms.is_forward
            planned_transforms = PlannedSpectralTransform(
                transform_group=self,
                tag=self.tag + "_" + transform_tag + "_" + field.name,
                symbolic_transform=transforms,
                custom_input_buffer=custom_input_buffer,
                matching_forward_transform=matching_forward_transform,
                action=action,
                dump_energy=dump_energy,
                plot_energy=plot_energy,
                **kwds,
            )
            self._backward_transforms[(field, axes, transform_tag)] = planned_transforms
        return planned_transforms

    @property
    def output_parameters(self):
        # union of the output parameters of every planned transform
        parameters = set()
        for pt in tuple(self._forward_transforms.values()) + tuple(
            self._backward_transforms.values()
        ):
            parameters.update(pt.output_parameters)
        return parameters

    @property
    def discrete_wave_numbers(self):
        assert self.discretized
        discrete_wave_numbers = self._discrete_wave_numbers
        if discrete_wave_numbers is None:
            msg = "discrete_wave_numbers has not been set yet."
            raise AttributeError(msg)
        return self._discrete_wave_numbers

    @_not_initialized
    def push_expressions(self, *exprs):
        """
        Parse symbolic expressions, collect their wavenumber symbols and
        return the wavenumbers encountered in the given expressions.
        """
        exprs_wave_numbers = set()
        for expr in exprs:
            assert isinstance(expr, sm.Basic)
            (e, transforms, wn) = STU.parse_expression(expr, replace_pows=True)
            self._expressions += (e,)
            self._wave_numbers.update(wn)
            for _wn in wn:
                if _wn not in self._indexed_wave_numbers:
                    self._indexed_wave_numbers[_wn] = _wn.indexed_buffer()
            exprs_wave_numbers.update(wn)
            if self.DEBUG:
                print(f"\n\nPARSING EXPRESSION {expr}")
                print(f"  new_expr: {e}")
                print(f"  transforms: {transforms}")
                print(f"  wave_numbers: {wn}")
        return tuple(exprs_wave_numbers)

    @classmethod
    def check_fields(cls, forward_fields, backward_fields):
        """Check that at least one field is given and that all share one domain."""
        all_fields = tuple(set(forward_fields + backward_fields))
        if not all_fields:
            msg = "At least one field is required."
            raise ValueError(msg)
        domain = cls.determine_domain(*all_fields)
        dim = domain.dim
        return (domain, dim)

    @classmethod
    def determine_domain(cls, *fields):
        """Return the common domain of all fields, raising on any mismatch."""
        domain = fields[0].domain
        for field in fields[1:]:
            if field.domain is not domain:
                msg = "Domain mismatch between fields:\n{}\nvs.\n{}\n"
                msg = msg.format(domain, field.domain)
                raise ValueError(msg)
        return domain
class PlannedSpectralTransform:
    """
    A planned spectral transform is an AppliedSpectralTransform wrapper.
    This object will be handled by the transform planner.
    """

    # When True, discretize() prints detailed planning information.
    DEBUG = False

    def __new__(
        cls,
        transform_group,
        tag,
        symbolic_transform,
        action,
        custom_input_buffer=None,
        custom_output_buffer=None,
        matching_forward_transform=None,
        dump_energy=None,
        plot_energy=None,
        compute_energy_frequencies=None,
        **kwds,
    ):
        # All named arguments are consumed by __init__; only **kwds are
        # forwarded up the MRO.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        transform_group,
        tag,
        symbolic_transform,
        action,
        custom_input_buffer=None,
        custom_output_buffer=None,
        matching_forward_transform=None,
        dump_energy=None,
        plot_energy=None,
        compute_energy_frequencies=None,
        **kwds,
    ):
        """Wrap an AppliedSpectralTransform for planning.

        Parameters
        ----------
        transform_group: SpectralTransformGroup
            Group this planned transform belongs to (owns the operator).
        tag: str
            Unique tag identifying this transform.
        symbolic_transform: AppliedSpectralTransform
            The symbolic transform (field, direction, per-axis transforms).
        action: SpectralTransformAction
            How the result is written to the output (overwrite/accumulate).
        custom_input_buffer: str, optional
            One of None, 'B0', 'B1', 'auto'; only valid for backward
            transforms ('auto' requires matching_forward_transform).
        custom_output_buffer: str, optional
            One of None, 'B0', 'B1', 'auto'; only valid for forward
            transforms.
        matching_forward_transform: PlannedSpectralTransform, optional
            Forward transform whose output buffer is reused when
            custom_input_buffer is 'auto'.
        dump_energy, plot_energy: IOParams, optional
            Enable energy dumping/plotting when frequency >= 0.
        compute_energy_frequencies: set of int, optional
            Additional frequencies at which energy is computed.
        """
        super().__init__(**kwds)
        check_instance(transform_group, SpectralTransformGroup)
        check_instance(transform_group.op, SpectralOperatorBase)
        check_instance(tag, str)
        check_instance(symbolic_transform, AppliedSpectralTransform)
        check_instance(action, SpectralTransformAction)
        check_instance(dump_energy, IOParams, allow_none=True)
        check_instance(plot_energy, IOParams, allow_none=True)
        assert custom_input_buffer in (None, "B0", "B1", "auto"), custom_input_buffer
        assert custom_output_buffer in (None, "B0", "B1", "auto"), custom_output_buffer

        field = symbolic_transform.field
        is_forward = symbolic_transform.is_forward

        self._transform_group = transform_group
        self._tag = tag
        self._symbol = symbolic_transform
        self._queue = None
        self._custom_input_buffer = custom_input_buffer
        self._custom_output_buffer = custom_output_buffer
        self._matching_forward_transform = matching_forward_transform
        self._action = action

        # Energy diagnostics are enabled only for non-negative frequencies.
        self._do_dump_energy = (dump_energy is not None) and (
            dump_energy.frequency >= 0
        )
        self._do_plot_energy = (plot_energy is not None) and (
            plot_energy.frequency >= 0
        )
        compute_energy_frequencies = to_set(
            first_not_None(compute_energy_frequencies, set())
        )
        if self._do_dump_energy:
            compute_energy_frequencies.add(dump_energy.frequency)
        if self._do_plot_energy:
            compute_energy_frequencies.add(plot_energy.frequency)
        # Negative frequencies are filtered out; energy is computed iff at
        # least one valid frequency remains.
        compute_energy_frequencies = set(
            filter(lambda f: f >= 0, compute_energy_frequencies)
        )
        do_compute_energy = len(compute_energy_frequencies) > 0
        self._do_compute_energy = do_compute_energy
        self._compute_energy_frequencies = compute_energy_frequencies
        self._plot_energy_ioparams = plot_energy
        self._dump_energy_ioparams = dump_energy

        if self._do_compute_energy:
            # Parameter name is prefixed 'Ef_' (forward) or 'Eb_' (backward);
            # shape/dtype are set later, once the field is discretized.
            ename = "E{}_{}".format("f" if is_forward else "b", field.name)
            pename = "E{}_{}".format("f" if is_forward else "b", field.pretty_name)
            vename = "E{}_{}".format("f" if is_forward else "b", field.var_name)
            self._energy_parameter = BufferParameter(
                name=ename,
                pretty_name=pename,
                var_name=vename,
                shape=None,
                dtype=None,
                initial_value=None,
            )
        else:
            self._energy_parameter = None
        self._energy_dumper = None
        self._energy_plotter = None

        # Validate custom buffer combinations per transform direction.
        if is_forward:
            msg = "Cannot specify 'custom_input_buffer' for a forward transform."
            assert custom_input_buffer is None, msg
            msg = "Cannot specify 'matching_forward_transform' for a forward transform."
            assert matching_forward_transform is None, msg
        else:
            msg = "Cannot specify 'custom_output_buffer' for a backward transform."
            assert self._custom_output_buffer is None, msg
            if self._custom_input_buffer == "auto":
                msg = "Using 'auto' as 'custom_output_buffer' of a backward transform implies "
                msg += "to specify a 'matching_forward_transform' to choose the buffer from."
                assert matching_forward_transform is not None, msg
                assert isinstance(matching_forward_transform, PlannedSpectralTransform), msg
                assert matching_forward_transform.is_forward, msg
            else:
                msg = "Using 'custom_output_buffer' different than 'auto' for a backward "
                msg += "transform implies to set 'matching_forward_transform' to None."
                assert matching_forward_transform is None, msg

        # reorder transforms in execution order (contiguous axe first)
        transforms = self.s.transforms[::-1]

        if len(transforms) != field.dim:
            msg = "Number of transforms does not match field dimension."
            raise ValueError(msg)
        if all((tr is TransformType.NONE) for tr in transforms):
            msg = "All transforms are of type NONE."
            raise ValueError(msg)

        # Input/output dtypes are deduced from the field dtype and the
        # per-axis transform chain (forward: field -> spectral; backward:
        # spectral -> field).
        if is_forward:
            input_dtype = field.dtype
            output_dtype = STU.determine_output_dtype(field.dtype, *transforms)
        else:
            input_dtype = STU.determine_input_dtype(field.dtype, *transforms)
            output_dtype = field.dtype
        self._input_dtype = np.dtype(input_dtype)
        self._output_dtype = np.dtype(output_dtype)

        # Shapes/buffers are determined later, during discretize()/setup().
        self._input_shape = None
        self._output_shape = None
        self._input_buffer = None
        self._output_buffer = None
        self._dfield = None
        self._input_symbolic_arrays = set()
        self._output_symbolic_arrays = set()
        self._ready = False

    @property
    def output_parameters(self):
        # Set difference with {None} yields the empty set when no energy
        # parameter was created.
        return {self._energy_parameter} - {None}
[docs] def input_symbolic_array(self, name, **kwds): """Create a symbolic array that will be bound to input transform array.""" assert "memory_object" not in kwds assert "dim" not in kwds obj = SymbolicArray(name=name, memory_object=None, dim=self.field.dim, **kwds) self._input_symbolic_arrays.add(obj) return obj
[docs] def output_symbolic_array(self, name, **kwds): """Create a symbolic array that will be bound to output transform array.""" assert "memory_object" not in kwds assert "dim" not in kwds obj = SymbolicArray(name=name, memory_object=None, dim=self.field.dim, **kwds) self._output_symbolic_arrays.add(obj) return obj
    # ------------------------------------------------------------------
    # Planning-time accessors (always available after __init__)
    # ------------------------------------------------------------------

    @property
    def transform_group(self):
        # Owning SpectralTransformGroup.
        return self._transform_group

    @property
    def op(self):
        # Owning spectral operator (via the transform group).
        return self._transform_group.op

    @property
    def tag(self):
        return self._tag

    @property
    def name(self):
        # Alias of tag.
        return self._tag

    @property
    def symbol(self):
        # Wrapped AppliedSpectralTransform.
        return self._symbol

    @property
    def s(self):
        # Shorthand alias of symbol.
        return self._symbol

    @property
    def field(self):
        return self._symbol.field

    @property
    def is_forward(self):
        return self._symbol.is_forward

    @property
    def is_backward(self):
        return not self.is_forward

    @property
    def transforms(self):
        # Per-axis transform types, in field axis order.
        return self._symbol.transforms

    @property
    def input_dtype(self):
        return self._input_dtype

    @property
    def output_dtype(self):
        return self._output_dtype

    # ------------------------------------------------------------------
    # Discretization-time accessors (valid only once the operator has
    # been discretized; raise AttributeError if not yet set)
    # ------------------------------------------------------------------

    @property
    def backend(self):
        assert self.discretized
        backend = self._backend
        if backend is None:
            msg = "backend has not been set yet."
            raise AttributeError(msg)
        return backend

    @property
    def dfield(self):
        # Discrete field this transform operates on.
        assert self.discretized
        if self._dfield is None:
            msg = "dfield has not been set."
            raise AttributeError(msg)
        return self._dfield

    @property
    def input_shape(self):
        assert self.discretized
        if self._input_shape is None:
            msg = "input_shape has not been set."
            raise AttributeError(msg)
        return self._input_shape

    @property
    def output_shape(self):
        assert self.discretized
        if self._output_shape is None:
            msg = "output_shape has not been set."
            raise AttributeError(msg)
        return self._output_shape

    @property
    def input_transform_shape(self):
        assert self.discretized
        if self._input_transform_shape is None:
            msg = "input_transform_shape has not been set."
            raise AttributeError(msg)
        return self._input_transform_shape

    @property
    def output_transform_shape(self):
        assert self.discretized
        if self._output_transform_shape is None:
            msg = "output_transform_shape has not been set."
            raise AttributeError(msg)
        return self._output_transform_shape

    @property
    def input_axes(self):
        assert self.discretized
        if self._input_axes is None:
            msg = "input_axes has not been set."
            raise AttributeError(msg)
        return self._input_axes

    @property
    def output_axes(self):
        assert self.discretized
        if self._output_axes is None:
            msg = "output_axes has not been set."
            raise AttributeError(msg)
        return self._output_axes

    @property
    def input_slices(self):
        assert self.discretized
        buf = self._input_slices
        if buf is None:
            msg = "input_slices has not been set yet."
            raise AttributeError(msg)
        return buf

    @property
    def output_slices(self):
        assert self.discretized
        buf = self._output_slices
        if buf is None:
            msg = "output_slices has not been set yet."
            raise AttributeError(msg)
        return buf

    @property
    def input_buffer(self):
        assert self.discretized
        buf = self._input_buffer
        if buf is None:
            msg = "input_buffer has not been set yet."
            raise AttributeError(msg)
        return buf

    @property
    def output_buffer(self):
        assert self.discretized
        buf = self._output_buffer
        if buf is None:
            msg = "output_buffer has not been set yet."
            raise AttributeError(msg)
        return buf

    @property
    def full_input_buffer(self):
        assert self.discretized
        buf = self._full_input_buffer
        if buf is None:
            msg = "full_input_buffer has not been set yet."
            raise AttributeError(msg)
        return buf

    @property
    def full_output_buffer(self):
        assert self.discretized
        buf = self._full_output_buffer
        if buf is None:
            msg = "full_output_buffer has not been set yet."
            raise AttributeError(msg)
        return buf

    # ------------------------------------------------------------------
    # State flags (initialized/discretized delegate to the operator)
    # ------------------------------------------------------------------

    @property
    def initialized(self):
        return self.op.initialized

    @property
    def discretized(self):
        return self.op.discretized

    @property
    def ready(self):
        return self._ready
    @_not_initialized
    def initialize(self, **kwds):
        # Nothing to do at initialization time: all planning happens in
        # discretize() and setup().
        pass
[docs] @_initialized def discretize(self, **kwds): is_forward = self.is_forward dim = self.field.dim field_axes = TranspositionState[dim].default_axes() if is_forward: (dfield, transform_info, transpose_info, transform_offsets) = ( self._discretize_forward(field_axes, **kwds) ) assert transpose_info[0][1] == field_axes else: (dfield, transform_info, transpose_info, transform_offsets) = ( self._discretize_backward(field_axes, **kwds) ) assert transpose_info[-1][2] == field_axes assert dfield.dim == len(transform_info) == len(transpose_info) == dim assert transform_info[0][2][1] == self._input_dtype assert transform_info[-1][3][1] == self._output_dtype # filter out untransformed axes tidx = tuple( filter(lambda i: not STU.is_none(transform_info[i][1]), range(dim)) ) assert tidx, "Could not determine any transformed axe." ntransforms = len(tidx) transform_info = tuple(map(transform_info.__getitem__, tidx)) transpose_info = tuple(map(transpose_info.__getitem__, tidx)) assert len(transform_info) == len(transpose_info) == ntransforms # determine input and output shapes input_axes = transpose_info[0][1] output_axes = transpose_info[-1][2] if is_forward: assert field_axes == input_axes, (field_axes, input_axes) input_transform_shape = transpose_info[0][3] output_transform_shape = transform_info[-1][3][0] input_shape, input_slices, _ = self.determine_buffer_shape( input_transform_shape, False, transform_offsets, input_axes ) output_shape, output_slices, zfos = self.determine_buffer_shape( output_transform_shape, True, transform_offsets, output_axes ) # We have a situation where we should impose zeros: # 1) output transform ghosts (when there are transform sizes mismatch DXT-I variants) zero_fill_output_slices = zfos else: assert field_axes == output_axes, (field_axes, output_axes) input_transform_shape = transform_info[0][2][0] output_transform_shape = transpose_info[-1][4] input_shape, input_slices, _ = self.determine_buffer_shape( input_transform_shape, True, 
transform_offsets, input_axes ) output_shape, output_slices, zfos = self.determine_buffer_shape( output_transform_shape, False, transform_offsets, output_axes ) # We have a situation where we should impose zeros: # 1) impose homogeneous dirichlet conditions on output # (implicit 0's are not part of the transform output). zero_fill_output_slices = zfos axes = output_axes if is_forward else input_axes ptransforms = tuple(self.transforms[i] for i in axes) self._permuted_transforms = ptransforms if self._do_compute_energy: shape = output_shape if is_forward else input_shape # view = (output_slices if is_forward else input_slices) assert len(shape) == ntransforms shape = tuple( Si - 2 if sum(transform_offsets[i]) == 2 else Si for i, Si in zip(axes, shape) ) K2 = () for tr, Ni in zip(ptransforms, shape): Ki = Ni // 2 if STU.is_C2C(tr) else Ni - 1 K2 += (Ki * Ki,) max_wavenumber = int(round(sum(K2) ** 0.5, 0)) energy_nbytes = compute_nbytes(max_wavenumber + 1, dfield.dtype) if dfield.backend.kind == Backend.OPENCL: mutexes_nbytes = compute_nbytes(max_wavenumber + 1, np.int32) else: mutexes_nbytes = 0 self._max_wavenumber = max_wavenumber self._energy_nbytes = energy_nbytes self._mutexes_nbytes = mutexes_nbytes Ep = self._energy_parameter Ep.reallocate_buffer(shape=(max_wavenumber + 1,), dtype=dfield.dtype) fname = fname = "{}{}".format(dfield.name, "_in" if is_forward else "_out") # build txt dumper if self._do_dump_energy: diop = self._dump_energy_ioparams assert diop is not None self._energy_dumper = EnergyDumper( energy_parameter=Ep, io_params=self._dump_energy_ioparams, fname=fname, ) # build plotter if required if self._do_plot_energy: piop = self._plot_energy_ioparams assert piop is not None pname = "{}.{}.{}".format( self.op.__class__.__name__, "forward" if is_forward else "backward", dfield.pretty_name, ) energy_parameters = {pname: self._energy_parameter} self._energy_plotter = EnergyPlotter( energy_parameters=energy_parameters, 
io_params=self._plot_energy_ioparams, fname=fname, ) else: self._max_wavenumber = None self._energy_nbytes = None self._mutexes_nbytes = None self._dfield = dfield self._transform_info = transform_info self._transpose_info = transpose_info self._ntransforms = ntransforms self._input_axes = input_axes self._input_shape = input_shape self._input_slices = input_slices self._input_transform_shape = input_transform_shape self._output_axes = output_axes self._output_shape = output_shape self._output_slices = output_slices self._output_transform_shape = output_transform_shape self._zero_fill_output_slices = zero_fill_output_slices self._backend = dfield.backend if self.DEBUG: def axis_format(info): prefix = "\n" + " " * 4 ss = "" for i, data in enumerate(info): ss += prefix + f"{i}/ " + str(data) return ss def slc_format(slices): if slices is None: return "NONE" else: prefix = "\n" + " " * 4 ss = "" for slc in slices: ss += prefix + str(slc) return ss print(f"\n\n== SPECTRAL PLANNING INFO OF FIELD {dfield.pretty_name} ==") print( "transform direction: {}".format( "FORWARD" if self.is_forward else "BACKWARD" ) ) print(f"transforms: {self.transforms}") print(":CARTESIAN INFO:") print(f"cart shape: {dfield.topology.cart_shape}") print(f"global grid resolution: {dfield.mesh.grid_resolution}") print(f"local grid resolution: {dfield.compute_resolution}") print(":INPUT:") print(f"input axes: {self._input_axes}") print(f"input dtype: {self._input_dtype}") print(f"input transform shape: {self._input_transform_shape}") print(f"input shape: {self._input_shape}") print(f"input slices: {self._input_slices}") print(":OUTPUT:") print(f"output axes: {self._output_axes}") print(f"output_dtype: {self._output_dtype}") print(f"output transform shape: {self._output_transform_shape}") print(f"output shape: {self._output_shape}") print(f"output_slices: {self._output_slices}") print(":TRANSFORM INFO:") print(f"transform_info: {axis_format(transform_info)}") print(":TRANSPOSE INFO:") 
print(f"transpose_info: {axis_format(transpose_info)}") print(":ZERO FILL:") print( f"zero_fill_output_slices: {slc_format(self._zero_fill_output_slices)}" )
[docs] def get_mapped_input_buffer(self): return self.get_mapped_full_input_buffer()[self.input_slices]
[docs] def get_mapped_output_buffer(self): return self.get_mapped_full_output_buffer()[self.output_slices]
[docs] def get_mapped_full_input_buffer(self): dfield = self._dfield if ( self.is_forward and dfield.backend.kind == Backend.OPENCL and self.transform_group._op.enable_opencl_host_buffer_mapping ): return self.transform_group._op.get_mapped_object(dfield)[ dfield.compute_slices ] else: return self.full_input_buffer
[docs] def get_mapped_full_output_buffer(self): dfield = self._dfield if ( self.is_backward and dfield.backend.kind == Backend.OPENCL and self.transform_group._op.enable_opencl_host_buffer_mapping ): return self.transform_group._op.get_mapped_object(dfield)[ dfield.compute_slices ] else: return self.full_output_buffer
[docs] def determine_buffer_shape(self, transform_shape, target_is_buffer, offsets, axes): offsets = tuple(offsets[ai] for ai in axes) slices = [] shape = [] zero_fill_slices = [] dim = len(axes) for i, ((lo, ro), si) in enumerate(zip(offsets, transform_shape)): if (lo ^ ro) and target_is_buffer: Si = si slc = slice(0, si) else: Si = si + lo + ro slc = slice(lo, Si - ro) if lo > 0: zfill = [slice(None, None, None)] * dim zfill[i] = slice(0, lo) zfill = tuple(zfill) zero_fill_slices.append(zfill) if ro > 0: zfill = [slice(None, None, None)] * dim zfill[i] = slice(Si - ro, Si) zfill = tuple(zfill) zero_fill_slices.append(zfill) shape.append(Si) slices.append(slc) return tuple(shape), tuple(slices), tuple(zero_fill_slices)
    def configure_input_buffer(self, buf):
        """Bind *buf* as the full input buffer of this transform.

        *buf* must hold at least input_shape x input_dtype bytes; when its
        shape or dtype differ it is reinterpreted (byte view + reshape)
        rather than copied. The input view (input_slices) of the bound
        buffer is stored and returned, and all registered input symbolic
        arrays are rebound to the new memory object.
        """
        input_dtype, input_shape = self.input_dtype, self.input_shape
        buf_nbytes = compute_nbytes(buf.shape, buf.dtype)
        input_nbytes = compute_nbytes(input_shape, input_dtype)
        assert buf_nbytes >= input_nbytes, (buf_nbytes, input_nbytes)
        if (buf.shape != input_shape) or (buf.dtype != input_dtype):
            # Reinterpret the first input_nbytes bytes as the expected
            # dtype/shape (zero-copy).
            buf = (
                buf.view(dtype=np.int8)[:input_nbytes]
                .view(dtype=input_dtype)
                .reshape(input_shape)
            )
        # NOTE(review): unwrap hysop Array wrappers to their raw handle;
        # source extraction lost indentation here — presumably this applies
        # to any bound buffer, not only reshaped ones (confirm upstream).
        if isinstance(buf, Array):
            buf = buf.handle
        input_buffer = buf[self.input_slices]
        assert input_buffer.shape == self.input_transform_shape
        self._full_input_buffer = buf
        self._input_buffer = input_buffer
        for symbol in self._input_symbolic_arrays:
            symbol.to_backend(self.backend.kind).bind_memory_object(buf)
        return input_buffer
    def configure_output_buffer(self, buf):
        """Bind *buf* as the full output buffer of this transform.

        Mirror of configure_input_buffer for the output side: *buf* must
        hold at least output_shape x output_dtype bytes and is reinterpreted
        (byte view + reshape) when shape or dtype differ. The output view
        (output_slices) is stored and returned, and all registered output
        symbolic arrays are rebound to the new memory object.
        """
        output_dtype, output_shape = self.output_dtype, self.output_shape
        buf_nbytes = compute_nbytes(buf.shape, buf.dtype)
        output_nbytes = compute_nbytes(output_shape, output_dtype)
        assert buf_nbytes >= output_nbytes, (buf_nbytes, output_nbytes)
        if (buf.shape != output_shape) or (buf.dtype != output_dtype):
            # Reinterpret the first output_nbytes bytes as the expected
            # dtype/shape (zero-copy).
            buf = (
                buf.view(dtype=np.int8)[:output_nbytes]
                .view(dtype=output_dtype)
                .reshape(output_shape)
            )
        # NOTE(review): unwrap hysop Array wrappers to their raw handle;
        # source extraction lost indentation here — presumably this applies
        # to any bound buffer, not only reshaped ones (confirm upstream).
        if isinstance(buf, Array):
            buf = buf.handle
        output_buffer = buf[self.output_slices]
        assert output_buffer.shape == self.output_transform_shape
        self._full_output_buffer = buf
        self._output_buffer = output_buffer
        for symbol in self._output_symbolic_arrays:
            symbol.to_backend(self.backend.kind).bind_memory_object(buf)
        return output_buffer
def _discretize_forward(self, field_axes, **kwds): dfield = self.op.input_discrete_fields[self.field] grid_resolution = dfield.mesh.grid_resolution local_resolution = dfield.compute_resolution input_dtype = dfield.dtype dim = dfield.dim forward_transforms = self.transforms[::-1] backward_transforms = STU.get_inverse_transforms(*forward_transforms) (resolution, transform_offsets) = STU.get_transform_resolution( local_resolution, *forward_transforms ) local_transform_info = self._determine_transform_info( forward_transforms, resolution, input_dtype ) local_transpose_info = self._determine_transpose_info( field_axes, local_transform_info ) local_transform_info = self._permute_transform_info( local_transform_info, local_transpose_info ) transform_info = local_transform_info transpose_info = local_transpose_info return (dfield, transform_info, transpose_info, transform_offsets) def _discretize_backward(self, field_axes, **kwds): forward_transforms = self.transforms[::-1] backward_transforms = STU.get_inverse_transforms(*forward_transforms) def reverse_transform_info(transform_info): transform_info = list(transform_info) for i, d in enumerate(transform_info): d = list(d) d[1] = forward_transforms[i] d2, d3 = d[2:4] d[2:4] = d3, d2 transform_info[i] = tuple(d) transform_info = tuple(transform_info) return transform_info[::-1] def reverse_transpose_info(transpose_info): transpose_info = list(transpose_info) for i, d in enumerate(transpose_info): if d[0] is not None: d = list(d) d1, d2, d3, d4 = d[1:5] d[1:5] = d2, d1, d4, d3 d[0] = tuple(d[1].index(ai) for ai in d[2]) d = tuple(d) else: # no permutation assert d[1] == d[2] assert d[3] == d[4] transpose_info[i] = d return transpose_info[::-1] dfield = self.op.output_discrete_fields[self.field] grid_resolution = dfield.mesh.grid_resolution local_resolution = dfield.compute_resolution output_dtype = dfield.dtype dim = dfield.dim (resolution, transform_offsets) = STU.get_transform_resolution( local_resolution, 
*backward_transforms ) local_backward_transform_info = self._determine_transform_info( backward_transforms, resolution, output_dtype ) local_backward_transpose_info = self._determine_transpose_info( field_axes, local_backward_transform_info ) local_backward_transform_info = self._permute_transform_info( local_backward_transform_info, local_backward_transpose_info ) local_forward_transform_info = reverse_transform_info( local_backward_transform_info ) local_forward_transpose_info = reverse_transpose_info( local_backward_transpose_info ) transform_info = local_forward_transform_info transpose_info = local_forward_transpose_info return (dfield, transform_info, transpose_info, transform_offsets) @classmethod def _determine_transform_info(cls, transforms, src_shape, src_dtype): transform_info = [] dim = len(transforms) dst_shape, dst_dtype = src_shape, src_dtype dst_view = [slice(0, si) for si in src_shape] for i, tr in enumerate(transforms): axis = i src_shape = dst_shape src_dtype = dst_dtype src_view = dst_view if STU.is_none(tr): pass elif STU.is_backward(tr): msg = "{} is not a forward transform." msg = msg.format(tr) raise ValueError(msg) elif STU.is_R2R(tr): msg = f"Expected a floating point data type but got {src_dtype}." assert is_fp(src_dtype), msg # data type and shape does not change elif STU.is_R2C(tr): msg = f"Expected a floating point data type but got {src_dtype}." assert is_fp(src_dtype), msg dst_shape = list(src_shape) dst_shape[dim - axis - 1] = dst_shape[dim - axis - 1] // 2 + 1 dst_shape = tuple(dst_shape) dst_dtype = float_to_complex_dtype(src_dtype) elif STU.is_C2C(tr): msg = f"Expected a complex data type but got {src_dtype}." assert is_complex(src_dtype), msg # data type and shape does not change else: msg = f"Unknown transform type {tr}." 
raise ValueError(msg) (lo, ro) = STU.get_transform_offsets(tr) src_view = src_view[:] src_view[dim - axis - 1] = slice(lo, src_shape[dim - axis - 1] - ro) dst_view = src_view[:] dst_view[dim - axis - 1] = slice(lo, dst_shape[dim - axis - 1] - ro) src_dtype = np.dtype(src_dtype) dst_dtype = np.dtype(dst_dtype) data = ( axis, tr, (src_shape, src_dtype, tuple(src_view)), (dst_shape, dst_dtype, tuple(dst_view)), ) transform_info.append(data) transform_info = tuple(transform_info) return transform_info @classmethod def _determine_transpose_info(cls, src_axes, transform_info): transpose_info = [] dim = len(src_axes) for ( axis, tr, (src_shape, src_dtype, src_view), (dst_shape, dst_dtype, dst_view), ) in transform_info: dst_axis = dim - 1 - axis if (not STU.is_none(tr)) and (dst_axis != src_axes[-1]): idx = src_axes.index(dst_axis) dst_axes = list(src_axes) dst_axes[idx] = src_axes[-1] dst_axes[-1] = dst_axis dst_axes = tuple(dst_axes) permutation = tuple(src_axes.index(ai) for ai in dst_axes) else: dst_axes = src_axes permutation = None dst_shape = tuple(src_shape[ai] for ai in dst_axes) src_shape = tuple(src_shape[ai] for ai in src_axes) data = (permutation, src_axes, dst_axes, src_shape, dst_shape) transpose_info.append(data) src_axes = dst_axes transpose_info = tuple(transpose_info) return transpose_info @classmethod def _permute_transform_info(cls, transform_info, transpose_info): assert len(transform_info) == len(transpose_info) transform_info = list(transform_info) for i, (transpose, transform) in enumerate(zip(transpose_info, transform_info)): (_, _, dst_axes, _, transpose_out_shape) = transpose (_1, _2, (src_shape, _3, src_view), (dst_shape, _4, dst_view)) = transform permuted_src_shape = tuple(src_shape[ai] for ai in dst_axes) permuted_src_view = tuple(src_view[ai] for ai in dst_axes) permuted_dst_shape = tuple(dst_shape[ai] for ai in dst_axes) permuted_dst_view = tuple(dst_view[ai] for ai in dst_axes) assert permuted_src_shape == transpose_out_shape transform = 
( _1, _2, (permuted_src_shape, _3, permuted_src_view), (permuted_dst_shape, _4, permuted_dst_view), ) transform_info[i] = transform transform_info = tuple(transform_info) return transform_info
    @_discretized
    def get_mem_requests(self, **kwds):
        """Return the memory requests (tag -> nbytes) for this transform.

        B0/B1 are the two ping-pong buffers sized to the largest
        intermediate array; TMP is the extra scratch space required by the
        FFT plans, measured by building throwaway plans on temporary
        buffers. ENERGY/MUTEXES requests are added when energy computation
        is enabled.
        """
        # first we need to find out src and dst buffers for transforms (B0 and B1)
        nbytes = 0
        for (
            _,
            _,
            (src_shape, src_dtype, src_view),
            (dst_shape, dst_dtype, dst_view),
        ) in self._transform_info:
            nbytes = max(nbytes, compute_nbytes(src_shape, src_dtype))
            nbytes = max(nbytes, compute_nbytes(dst_shape, dst_dtype))
        nbytes = max(nbytes, compute_nbytes(self.input_shape, self.input_dtype))
        nbytes = max(nbytes, compute_nbytes(self.output_shape, self.output_dtype))

        # Then we need to find out the size of an additional tmp buffer
        # we can only do it by creating temporary plans prior to setup
        # with temporary buffers.
        tmp_nbytes = 0
        tg = self.transform_group
        src = tg.FFTI.backend.empty(
            shape=(nbytes,), dtype=np.uint8, min_alignment=tg.op.min_fft_alignment
        )
        dst = tg.FFTI.backend.empty(
            shape=(nbytes,), dtype=np.uint8, min_alignment=tg.op.min_fft_alignment
        )
        queue = tg.FFTI.new_queue(tg=tg, name="tmp_queue")
        for (
            _,
            tr,
            (src_shape, src_dtype, src_view),
            (dst_shape, dst_dtype, dst_view),
        ) in self._transform_info:
            src_nbytes = compute_nbytes(src_shape, src_dtype)
            dst_nbytes = compute_nbytes(dst_shape, dst_dtype)
            # View the raw byte buffers as the per-step dtype/shape.
            b0 = src[:src_nbytes].view(dtype=src_dtype).reshape(src_shape)
            b1 = dst[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape)
            fft_plan = tg.FFTI.get_transform(tr)(
                a=b0.handle, out=b1.handle, axis=self.field.dim - 1, verbose=False
            )
            fft_plan.setup(queue=queue)
            tmp_nbytes = max(tmp_nbytes, fft_plan.required_buffer_size)
        # Release the probe buffers before requesting real ones.
        del src
        del dst

        if tmp_nbytes > nbytes:
            msg = "Planner claims to need more than buffer bytes as temporary buffer:"
            msg += f"\n *Buffer bytes: {bytes2str(nbytes)}"
            msg += f"\n *Tmp bytes: {bytes2str(tmp_nbytes)}"
            warnings.warn(msg, HysopFFTWarning)

        backend = self.transform_group.backend
        mem_tag = self.transform_group.mem_tag
        # NOTE(review): field_tag is assigned but never used below.
        field_tag = self.dfield.name
        kind = backend.kind
        B0_tag = f"{mem_tag}_{kind}_B0"
        B1_tag = f"{mem_tag}_{kind}_B1"
        TMP_tag = f"{mem_tag}_{kind}_TMP"
        ENERGY_tag = f"{mem_tag}_{kind}_ENERGY"
        MUTEXES_tag = f"{mem_tag}_{kind}_MUTEXES"
        self.B0_tag, self.B1_tag, self.TMP_tag, self.ENERGY_tag, self.MUTEXES_tag = (
            B0_tag,
            B1_tag,
            TMP_tag,
            ENERGY_tag,
            MUTEXES_tag,
        )
        requests = {B0_tag: nbytes, B1_tag: nbytes, TMP_tag: tmp_nbytes}
        if self._do_compute_energy:
            if self._energy_nbytes > 0:
                requests[ENERGY_tag] = self._energy_nbytes
            if self._mutexes_nbytes > 0:
                requests[MUTEXES_tag] = self._mutexes_nbytes
        return requests
[docs] @_discretized def setup(self, work): SETUP_DEBUG = False assert not self.ready dim = self.field.dim op = self.op tg = self.transform_group FFTI = tg.FFTI is_forward = self.is_forward is_backward = self.is_backward ntransforms = self._ntransforms transform_info = self._transform_info transpose_info = self._transpose_info B0_tag, B1_tag = self.B0_tag, self.B1_tag TMP_tag = self.TMP_tag ENERGY_tag = self.ENERGY_tag MUTEXES_tag = self.MUTEXES_tag # get temporary buffers (B0,) = work.get_buffer(op, B0_tag, handle=True) (B1,) = work.get_buffer(op, B1_tag, handle=True) assert is_byte_aligned(B0) assert is_byte_aligned(B1) try: (TMP,) = work.get_buffer(op, TMP_tag, handle=True) except ValueError: TMP = None if (self._energy_nbytes is not None) and (self._energy_nbytes > 0): (ENERGY,) = work.get_buffer(op, ENERGY_tag, handle=True) energy_buffer = ENERGY[: self._energy_nbytes].view(dtype=self.dfield.dtype) assert energy_buffer.size == self._max_wavenumber + 1 else: ENERGY = None energy_buffer = None if (self._mutexes_nbytes is not None) and (self._mutexes_nbytes > 0): (MUTEXES,) = work.get_buffer(op, MUTEXES_tag, handle=True) mutexes_buffer = MUTEXES[: self._mutexes_nbytes].view(dtype=np.int32) assert mutexes_buffer.size == self._max_wavenumber + 1 else: MUTEXES = None mutexes_buffer = None # Bind transformed field buffer to input or output. # This only happens if the user did not bind another buffer prior to the setup. dfield = self.dfield if is_forward and (self._input_buffer is None): self.configure_input_buffer(dfield.sbuffer[dfield.compute_slices]) elif is_backward and (self._output_buffer is None): self.configure_output_buffer(dfield.sbuffer[dfield.compute_slices]) # bind group buffer to input or output if required. 
custom_input_buffer = self._custom_input_buffer custom_output_buffer = self._custom_output_buffer if is_forward and custom_output_buffer: if custom_output_buffer == "auto": # will be determined and set later pass elif custom_output_buffer == "B0": self.configure_output_buffer(B0) elif custom_output_buffer == "B1": self.configure_output_buffer(B1) else: msg = f"Unknown custom output buffer {custom_output_buffer}." raise NotImplementedError(msg) if is_backward and custom_input_buffer: if custom_input_buffer == "auto": assert self._matching_forward_transform.ready custom_input_buffer = ( self._matching_forward_transform._custom_output_buffer ) assert custom_input_buffer in ("B0", "B1") if custom_input_buffer == "B0": self.configure_input_buffer(B0) elif custom_input_buffer == "B1": self.configure_input_buffer(B1) else: msg = f"Unknown custom input buffer {custom_input_buffer}." raise NotImplementedError(msg) # define input and output buffer, as well as tmp buffers src_buffer, dst_buffer = B0, B1 def nameof(buf): assert (buf is B0) or (buf is B1) if buf is B0: return "B0" else: return "B1" def check_size(buf, nbytes, name): if buf.nbytes < nbytes: msg = "Insufficient buffer size for buffer {} (shape={}, dtype={}).".format( name, buf.shape, buf.dtype ) msg += f"\nExpected at least {nbytes} bytes but got {buf.nbytes}." try: bname = nameof(buf) msg += f"\nThis buffer has been identified as {bname}." 
except: pass raise RuntimeError(msg) # build spectral transform execution queue qname = "fft_planner_{}_{}".format( self.field.name, "forward" if is_forward else "backward" ) queue = FFTI.new_queue(tg=self, name=qname) if SETUP_DEBUG: def print_op(description, category): prefix = " |> " print(f"{prefix}{description: <40}[{category}]") msg = """ SPECTRAL TRANSFORM SETUP op: {} dim: {} ntransforms: {} group_tag: {} is_forward: {} is_backward: {}""".format( op.pretty_tag, dim, ntransforms, self.tag, is_forward, is_backward ) print(msg) fft_plans = () for i in range(ntransforms): transpose = transpose_info[i] transform = transform_info[i] (permutation, _, _, input_shape, output_shape) = transpose ( _, tr, (src_shape, src_dtype, src_view), (dst_shape, dst_dtype, dst_view), ) = transform assert not STU.is_none(tr), "Got a NONE transform type." is_first = i == 0 is_last = i == ntransforms - 1 should_forward_permute = is_forward and (permutation is not None) should_backward_permute = is_backward and (permutation is not None) if SETUP_DEBUG: msg = f" TRANSFORM INDEX {i}:" if permutation is not None: msg += """ Transpose Info: permutation: {} input_shape: {} output_shape: {} forward_permute: {} backward_permute: {}""".format( permutation, input_shape, output_shape, should_forward_permute, should_backward_permute, ) msg += """ Custom buffers: custom_input: {} custom output: {} Transform Info: SRC: shape {} and type {}, view {} DST: shape {} and type {}, view {} Planned Operations:""".format( custom_input_buffer, custom_output_buffer, src_shape, src_dtype, src_view, dst_shape, dst_dtype, dst_view, ) print(msg) src_nbytes = compute_nbytes(src_shape, src_dtype) dst_nbytes = compute_nbytes(dst_shape, dst_dtype) # build forward permutation if required # (forward transforms transpose before actual transforms) if should_forward_permute: input_nbytes = compute_nbytes(input_shape, src_dtype) output_nbytes = compute_nbytes(output_shape, src_dtype) assert ( output_shape == src_shape ), 
"Transpose to Transform shape mismatch." assert ( input_nbytes == output_nbytes ), "Transpose input and output size mismatch." assert ( src_buffer.nbytes >= input_nbytes ), "Insufficient buffer size for src buf." assert ( dst_buffer.nbytes >= output_nbytes ), "Insufficient buffer size for dst buf." if is_first: assert ( self.input_buffer.shape == input_shape ), "input_buffer shape mismatch." assert ( self.input_buffer.dtype == src_dtype ), "input_buffer dtype mismatch." b0 = self.get_mapped_input_buffer else: b0 = ( src_buffer[:input_nbytes] .view(dtype=src_dtype) .reshape(input_shape) ) b1 = ( dst_buffer[:output_nbytes] .view(dtype=src_dtype) .reshape(output_shape) ) queue += FFTI.plan_transpose(tg=tg, src=b0, dst=b1, axes=permutation) if SETUP_DEBUG: sfrom = "input_buffer" if is_first else nameof(src_buffer) sto = nameof(dst_buffer) print_op( f"PlanTranspose(src={sfrom}, dst={sto}, permutation={permutation})", "forward permute", ) src_buffer, dst_buffer = dst_buffer, src_buffer elif is_first: assert ( self.input_buffer.shape == src_shape ), "input buffer shape mismatch." assert ( self.input_buffer.dtype == src_dtype ), "input buffer dtype mismatch." assert ( src_buffer.nbytes >= src_nbytes ), "Insufficient buffer size for src buf." 
if (custom_input_buffer is not None) and ( nameof(src_buffer) == custom_input_buffer ): src_buffer, dst_buffer = dst_buffer, src_buffer b0 = src_buffer[:src_nbytes].view(dtype=src_dtype).reshape(src_shape) queue += FFTI.plan_copy(tg=tg, src=self.get_mapped_input_buffer, dst=b0) if SETUP_DEBUG: sfrom = "input_buffer" sto = nameof(src_buffer) print_op(f"PlanCopy(src={sfrom}, dst={sto})", "pre-transform copy") # build batched 1D transform in contiguous axis check_size(src_buffer, src_nbytes, "src") check_size(dst_buffer, dst_nbytes, "dst") b0 = src_buffer[:src_nbytes].view(dtype=src_dtype).reshape(src_shape) b1 = dst_buffer[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape) fft_plan = FFTI.get_transform(tr)(a=b0, out=b1, axis=dim - 1) fft_plan.setup(queue=queue) fft_plans += (fft_plan,) queue += fft_plan if SETUP_DEBUG: sfrom = nameof(src_buffer) sto = nameof(dst_buffer) print_op(f"PlanTransform(src={sfrom}, dst={sto})", tr) src_buffer, dst_buffer = dst_buffer, src_buffer # build backward permutation if required # (backward transforms transpose after actual transforms) if should_backward_permute: input_nbytes = compute_nbytes(input_shape, dst_dtype) output_nbytes = compute_nbytes(output_shape, dst_dtype) assert ( input_shape == dst_shape ), "Transform to Transpose shape mismatch." assert ( input_nbytes == output_nbytes ), "Transpose input and output size mismatch." assert ( src_buffer.nbytes >= input_nbytes ), "Insufficient buffer size for src buf." assert ( dst_buffer.nbytes >= output_nbytes ), "Insufficient buffer size for dst buf." b0 = ( src_buffer[:input_nbytes].view(dtype=dst_dtype).reshape(input_shape) ) if is_last and (self._action is SpectralTransformAction.OVERWRITE): assert ( self.output_buffer.shape == output_shape ), "output buffer shape mismatch." assert ( self.output_buffer.dtype == dst_dtype ), "output buffer dtype mismatch." 
b1 = self.get_mapped_output_buffer else: b1 = ( dst_buffer[:output_nbytes] .view(dtype=dst_dtype) .reshape(output_shape) ) queue += FFTI.plan_transpose(tg=tg, src=b0, dst=b1, axes=permutation) if SETUP_DEBUG: sfrom = nameof(src_buffer) sto = "output_buffer" if is_last else nameof(dst_buffer) print_op( f"PlanTranspose(src={sfrom}, dst={sto})", "backward permute" ) src_buffer, dst_buffer = dst_buffer, src_buffer if is_last and (self._action is not SpectralTransformAction.OVERWRITE): if self._action is SpectralTransformAction.ACCUMULATE: assert ( self.output_buffer.shape == output_shape ), "output buffer shape mismatch." assert ( self.output_buffer.dtype == dst_dtype ), "output buffer dtype mismatch." queue += FFTI.plan_accumulate( tg=tg, src=b1, dst=self.get_mapped_output_buffer ) if SETUP_DEBUG: sfrom = nameof(dst_buffer) sto = "output_buffer" print_op( f"PlanAccumulate(src={sfrom}, dst={sto})", "post-transform accumulate", ) else: msg = f"Unsupported action {self._action}." raise NotImplementedError(msg) elif is_last: if custom_output_buffer is not None: if custom_output_buffer not in ("B0", "B1", "auto"): msg = f"Unknown custom output buffer {custom_output_buffer}." 
raise NotImplementedError(msg) elif custom_output_buffer == "auto": custom_output_buffer = nameof(dst_buffer) self._custom_output_buffer = custom_output_buffer if custom_output_buffer == "B0": self.configure_output_buffer(B0) elif custom_output_buffer == "B1": self.configure_output_buffer(B1) else: raise RuntimeError elif nameof(src_buffer) == custom_output_buffer: # This is a special case where we need to copy back and forth # (because of offsets) b0 = ( src_buffer[:dst_nbytes] .view(dtype=dst_dtype) .reshape(dst_shape) ) b1 = ( dst_buffer[:dst_nbytes] .view(dtype=dst_dtype) .reshape(dst_shape) ) queue += FFTI.plan_copy(tg=tg, src=b0, dst=b1) if SETUP_DEBUG: sfrom = nameof(src_buffer) sto = nameof(dst_buffer) print_op( f"PlanCopy(src={sfrom}, dst={sto})", "post-transform copy", ) src_buffer, dst_buffer = dst_buffer, src_buffer assert ( self.output_buffer.shape == dst_shape ), "output buffer shape mismatch." assert ( self.output_buffer.dtype == dst_dtype ), "output buffer dtype mismatch." assert ( src_buffer.nbytes >= dst_nbytes ), "Insufficient buffer size for src buf." b0 = src_buffer[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape) if self._action is SpectralTransformAction.OVERWRITE: pname = "PlanCopy" pdes = "post-transform-copy" queue += FFTI.plan_copy( tg=tg, src=b0, dst=self.get_mapped_output_buffer ) elif self._action is SpectralTransformAction.ACCUMULATE: pname = "PlanAccumulate" pdes = "post-transform-accumulate" queue += FFTI.plan_accumulate( tg=tg, src=b0, dst=self.get_mapped_output_buffer ) else: msg = f"Unsupported action {self._action}." 
raise NotImplementedError(msg) if SETUP_DEBUG: sfrom = nameof(src_buffer) sto = ( "output_buffer" if (custom_output_buffer is None) else custom_output_buffer ) print_op(f"{pname}(src={sfrom}, dst={sto})", pdes) if self._zero_fill_output_slices: buf = self.get_mapped_full_output_buffer slcs = self._zero_fill_output_slices queue += FFTI.plan_fill_zeros(tg=tg, a=buf, slices=slcs) if SETUP_DEBUG: print_op("PlanFillZeros(dst=output_buffer)", "post-transform-callback") # allocate fft plans FFTI.allocate_plans(op, fft_plans, tmp_buffer=TMP) # build kernels to compute energy if required if self._do_compute_energy: field_buffer = self.input_buffer if self.is_forward else self.output_buffer spectral_buffer = ( self.output_buffer if self.is_forward else self.input_buffer ) compute_energy_queue = FFTI.new_queue(tg=self, name="dump_energy") compute_energy_queue += FFTI.plan_fill_zeros( tg=tg, a=energy_buffer, slices=(Ellipsis,) ) if mutexes_buffer is not None: unlock_mutexes = FFTI.plan_fill_zeros( tg=tg, a=mutexes_buffer, slices=(Ellipsis,) ) compute_energy_queue += unlock_mutexes compute_energy_queue().wait() # we need this before compute energy to unlock mutexes compute_energy_queue += FFTI.plan_compute_energy( tg=tg, fshape=field_buffer.shape, src=spectral_buffer, dst=energy_buffer, transforms=self._permuted_transforms, mutexes=mutexes_buffer, ) compute_energy_queue += FFTI.plan_copy( tg=tg, src=energy_buffer, dst=self._energy_parameter._value ) else: compute_energy_queue = None self._frequency_ioparams = tuple( self.io_params.clone(frequency=f, with_last=True) for f in self._compute_energy_frequencies ) self._queue = queue self._compute_energy_queue = compute_energy_queue self._ready = True
def __call__(self, **kwds):
    """Execute the planned spectral transform.

    Runs the pre-transform hooks, then the recorded plan queue, then the
    post-transform hooks, chaining the returned events so each stage waits
    on the previous one.
    """
    assert self._ready
    assert self._queue is not None
    event = self._pre_transform_actions(**kwds)
    event = self._queue.execute(wait_for=event)
    return self._post_transform_actions(wait_for=event, **kwds)

def _pre_transform_actions(self, simulation=None, wait_for=None, **kwds):
    """Energy hooks executed before the transform queue.

    For a backward transform the spectral data is available *before* the
    queue runs, so energy is computed (and optionally plotted) here.
    Passing ``simulation=False`` skips all actions.
    """
    event = wait_for
    if simulation is False:
        return event
    if self.is_backward and self._do_compute_energy:
        event = self.compute_energy(simulation=simulation, wait_for=event)
        # NOTE(review): plotting reconstructed as nested under the compute
        # branch — plotting requires a freshly computed energy; confirm
        # against upstream sources if behavior looks off.
        if self._do_plot_energy:
            event = self.plot_energy(simulation=simulation, wait_for=event)
    return event

def _post_transform_actions(self, simulation=None, wait_for=None, **kwds):
    """Energy hooks executed after the transform queue.

    For a forward transform the spectral data only exists *after* the
    queue runs, so energy is computed (and optionally plotted) here.
    Passing ``simulation=False`` skips all actions.
    """
    event = wait_for
    if simulation is False:
        return event
    if self.is_forward and self._do_compute_energy:
        event = self.compute_energy(simulation=simulation, wait_for=event)
        if self._do_plot_energy:
            event = self.plot_energy(simulation=simulation, wait_for=event)
    return event
def compute_energy(self, simulation, wait_for):
    """Compute the spectral energy when an io frequency triggers it.

    The energy-compute queue is executed only if at least one of the
    registered ``IOParams`` says the current simulation step should dump;
    the dumper is then updated if energy dumping is enabled.  Returns the
    last chained event (``wait_for`` unchanged when nothing was computed).
    """
    msg = f"No simulation was passed in {type(self)}.__call__()."
    assert simulation is not None, msg
    event = wait_for
    triggered = any(
        iop.should_dump(simulation=simulation)
        for iop in self._frequency_ioparams
    )
    if triggered:
        event = self._compute_energy_queue(wait_for=event)
        if self._do_dump_energy:
            self._energy_dumper.update(simulation=simulation, wait_for=event)
    return event
def plot_energy(self, simulation, wait_for):
    """Refresh the energy plotter for the current simulation step.

    The plotter update is a host-side action: it does not produce a new
    device event, so the incoming ``wait_for`` event is returned as-is.
    """
    msg = f"No simulation was passed in {type(self)}.__call__()."
    assert simulation is not None, msg
    self._energy_plotter.update(simulation=simulation, wait_for=wait_for)
    return wait_for